import numpy as np
import os
import emcee
from multiprocessing import Pool, cpu_count
from scipy.stats.qmc import Sobol, scale
import pandas as pd

from jm.library.likelihood import Likelihood, Constraint
from jm.library.additional_helpers import get_cols_over_time
import warnings
# Suppress every warning for the rest of the run.
# NOTE(review): this also hides sampler/convergence warnings; consider
# narrowing the filter to specific categories.
warnings.filterwarnings("ignore")

# Setting the environment variable ``test_size=true`` switches on a quick
# smoke-test mode: every specification is fitted on only ``subsample_size``
# rows with a handful of MCMC steps.  Otherwise the full data set and the
# full burn-in / main chain lengths are used.
subsample = os.environ.get("test_size") == "true"
subsample_size = 20 if subsample else None
n_burn, max_n = (10, 10) if subsample else (1000, 4000)

def param_transform_a_b(params):
    """Convert the sampled ``b'`` into ``b = a + b'`` in place.

    In every non-logistic specification the parameter vector ends with
    ``a, b', sigma`` (see the bounds tables below), so ``params[-3]`` is ``a``
    and ``params[-2]`` is ``b'``.  Because ``b'`` is sampled with strictly
    positive bounds, the resulting ``b`` always exceeds ``a``.

    Args:
        params: mutable parameter vector.  It is modified in place.

    Returns:
        The same ``params`` object, with ``params[-2]`` replaced by
        ``params[-3] + params[-2]``.
    """
    params[-2] += params[-3]
    return params


# Each run writes "<prefix><file_ending>.h5" (main chain) and
# "<prefix><file_ending>_burn.h5" (burn-in chain, deleted at the end of the run).
_OUTPUT_PREFIX = "estimation/parameter_estimates/jesus_maria_mcmc_"

# Specification variants, in the order they are estimated, mapped to the
# suffix used in their output file names.
_FILE_ENDINGS = {
    'base': "G1actioninteraction",
    'timetrend': "timetrend",
    'valor_regime': "valorregime",
    'logistic': "logistic",
    'relaxedbound': "relaxedbound",
    'Q1': "Q1",
    'voluntary': "voluntary",
    'interaction_with_G1_deadline': "G1deadlineinteraction",
    'multicov_shareby3_age': 'multicov_shareby3_age',
    'pre_post_july': 'pre_post_july',
    'pre_post_june': 'pre_post_june',
    'with_calls_data': 'with_calls_data',
    'truncate_calls': 'with_calls_data_truncated500',
    'multicov_shareby3_age_missing': 'multicov_shareby3_age_missing',
}

# The ensemble has 2**_LOG2_NWALKERS walkers.  It must be a power of two
# because the burn-in starting points come from ``Sobol.random_base2``, which
# returns exactly 2**m points (one per walker).
_LOG2_NWALKERS = 7

# For each specification, run an MCMC burn-in from Sobol-spread starting
# points within the bounds, then a main run started from the burn-in's final
# walker positions.
for variant, file_ending in _FILE_ENDINGS.items():

    np.random.seed(0)

    # One boolean per specification.  At most one is True for a given variant,
    # and 'base' leaves them all False.  'truncate_calls' is the calls-data
    # specification with call counts capped at 500, so it also sets
    # ``with_calls_data``.
    timetrend = variant == 'timetrend'
    valor_regime = variant == 'valor_regime'
    logistic = variant == 'logistic'
    relaxedbound = variant == 'relaxedbound'
    Q1 = variant == 'Q1'
    voluntary = variant == 'voluntary'
    interaction_with_G1_deadline = variant == 'interaction_with_G1_deadline'
    multicov_shareby3_age = variant == 'multicov_shareby3_age'
    multicov_shareby3_age_missing = variant == 'multicov_shareby3_age_missing'
    pre_post_july = variant == 'pre_post_july'
    pre_post_june = variant == 'pre_post_june'
    truncate_calls = variant == 'truncate_calls'
    with_calls_data = variant == 'with_calls_data' or truncate_calls

    # NOTE(review): ``Q1`` and ``voluntary`` only change ``file_ending``.  They
    # are not passed to ``Likelihood`` and do not alter the data, covariates
    # or bounds below, so as written those runs repeat the 'base'
    # specification.  Confirm whether a data filter or Likelihood option is
    # missing.

    # Sanity check that the specifications are mutually exclusive.
    if sum([timetrend, valor_regime, logistic, voluntary, relaxedbound, Q1,
            interaction_with_G1_deadline, pre_post_july, pre_post_june,
            multicov_shareby3_age, with_calls_data,
            multicov_shareby3_age_missing]) > 1:
        raise Exception("Only one variation at a time!")

    df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')

    if subsample:
        # ``DataFrame.sample`` needs an integer ``n``.  Integer division gives
        # floor(2/3 of the rows) without producing a float.
        if subsample_size is None:
            n_rows = (2 * len(df_status)) // 3
        else:
            n_rows = subsample_size
        df_status = df_status.sample(n=n_rows, replace=False)

    df_status['above_median_G1_deadline'] = (
        df_status['days_from_G1_to_promise']
        > df_status['days_from_G1_to_promise'].median()
    )

    if interaction_with_G1_deadline:
        covariate = ['prob_repayment_endo_covariates', 'above_median_G1_deadline']
    elif multicov_shareby3_age:
        covariate = ['prob_repayment_endo_covariates', 'last_year_share_repaid_by_3', 'quantile_age']
    elif multicov_shareby3_age_missing:
        covariate = ['prob_repayment_endo_covariates', 'last_year_share_repaid_by_3',
                     'quantile_age', 'is_age_imputed', 'is_last_year_share_repaid_by_3_imputed']
    else:
        covariate = 'prob_repayment_endo_covariates'

    if truncate_calls:
        # Cap the per-period cumulative call counts at 500.
        call_cols = get_cols_over_time(col_fmt='total_calls_by_{}')
        df_status[call_cols] = np.minimum(df_status[call_cols], 500)

    # Prior box (lower, upper) for every parameter, in the order the
    # likelihood expects them.  Rows must line up with
    # ``likelihood.PARAM_NAMES``; this is checked below.
    if logistic:
        transformed_bounds = np.array([
            (-3, 3),  # payment > 0
            (-3, 3),  # payment
            (-3, 3),  # priority G1
            (-3, 3),  # priority G2
            (-3, 3),  # priority G3
            (0, 3),  # action valor
            (-3, 3),  # action rec1
            (-3, 3),  # action medida
            (-10, 10),  # covariate
            (-3, 3),  # G1 rec1
            (-3, 3),  # G1 medida
            (-5, 5),  # a
            (0, 5),  # b (no a + b' transform under the logistic shape)
            (0, 5)  # sigma
        ])
    elif relaxedbound:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (-.1, .1),  # action valor (sign no longer restricted)
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    elif interaction_with_G1_deadline:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate: prob_repayment_endo_covariates
            (-.5, .5),  # covariate: above_median_G1_deadline
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-1, 1),  # NOTE(review): unlabeled; presumably a G1-deadline interaction term, confirm
            (-1, 1),  # NOTE(review): unlabeled; presumably a G1-deadline interaction term, confirm
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    elif multicov_shareby3_age:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate: prob_repayment_endo_covariates
            (-.5, .5),  # covariate: last_year_share_repaid_by_3
            (-.5, .5),  # covariate: quantile_age
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    elif multicov_shareby3_age_missing:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate: prob_repayment_endo_covariates
            (-.5, .5),  # covariate: last_year_share_repaid_by_3
            (-.5, .5),  # covariate: quantile_age
            (-.5, .5),  # covariate: is_age_imputed
            (-.5, .5),  # covariate: is_last_year_share_repaid_by_3_imputed
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    elif pre_post_july or pre_post_june:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-.1, .1),  # post-cutoff (July or June) G1
            (-.1, .1),  # post-cutoff writ
            (-.1, .1),  # post-cutoff G1 writ
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    elif with_calls_data:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate
            (-.1, .1),  # calls
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])
    else:
        transformed_bounds = np.array([
            (-.1, .1),  # payment > 0
            (-.1, .1),  # payment
            (-.1, .1),  # priority G1
            (-.1, .1),  # priority G2
            (-.1, .1),  # priority G3
            (0, .1),  # action valor
            (-.1, .1),  # action rec1
            (-.1, .1),  # action medida
            (-.5, .5),  # covariate
            (-.1, .1),  # G1 rec1
            (-.1, .1),  # G1 medida
            (-0.2, 0.2),  # a
            (0.1, 0.5),  # b = a + b'
            (0, 1)  # sigma
        ])

    if timetrend:
        # Insert one extra (-.1, .1) row after "G1 medida" (the time-trend
        # coefficient, presumably).
        transformed_bounds = np.concatenate((transformed_bounds[0:11], [(-.1, .1)], transformed_bounds[11:]))
    elif valor_regime:
        # Replace the "action valor" row (index 5, originally (0, .1)) with
        # (-.1, .1), and insert one extra (-.1, .1) row after "G1 medida".
        # NOTE(review): presumably the two valor regimes; confirm against
        # Likelihood.PARAM_NAMES.
        transformed_bounds = np.concatenate((transformed_bounds[0:5], [(-.1, .1)],
                                             transformed_bounds[6:11], [(-.1, .1)],
                                             transformed_bounds[11:]))

    if logistic:
        options = {'sshape': 'logistic'}
    else:
        options = {'param_transform': param_transform_a_b}

    constraint = Constraint(transformed_bounds)
    likelihood = Likelihood(
        df_status, constraint, num_types=20,
        covariate=covariate,
        options=options, timetrend=timetrend,
        valor_regime=valor_regime,
        interaction_with_G1_deadline=interaction_with_G1_deadline,
        pre_post_july=pre_post_july,
        pre_post_june=pre_post_june,
        multicov_shareby3_age=multicov_shareby3_age,
        multicov_shareby3_age_missing=multicov_shareby3_age_missing,
        with_calls_data=with_calls_data)

    ndim = likelihood.num_params
    nwalkers = 2 ** _LOG2_NWALKERS  # 128

    # Every parameter needs a bounds row.  Otherwise the Sobol scaling below
    # fails with a much less informative error.
    if len(transformed_bounds) != ndim:
        raise ValueError(
            f"{variant}: {len(transformed_bounds)} bounds rows but the "
            f"likelihood has {ndim} parameters")

    # Kept at module scope (not inside a function) so that multiprocessing
    # can pickle it by name.  Pool workers then look up this function and
    # ``likelihood`` in their own copy of the module globals, presumably
    # inherited via the fork start method.  TODO confirm on spawn platforms.
    # NOTE(review): 13432 * 22 presumably rescales an average log-likelihood
    # to the full-sample total.  It is not adjusted when ``subsample`` is on.
    def log_likelihood(x):
        return 13432 * 22 * likelihood.log_likelihood_at_param(x)

    # Spread the burn-in walkers over the prior box with a scrambled Sobol
    # sequence (one point per walker).
    sobol = Sobol(ndim, scramble=True, seed=2)
    initial_values_burn = scale(sobol.random_base2(_LOG2_NWALKERS),
                                l_bounds=transformed_bounds[:, 0],
                                u_bounds=transformed_bounds[:, 1])

    # Leave one core free, but never ask for fewer than one worker process.
    n_processes = max(1, cpu_count() - 1)

    burn_path = _OUTPUT_PREFIX + file_ending + "_burn.h5"
    main_path = _OUTPUT_PREFIX + file_ending + ".h5"

    # Burn-in run, starting from a fresh backend file.
    if os.path.exists(burn_path):
        os.remove(burn_path)
    burnin_backend = emcee.backends.HDFBackend(burn_path)
    burnin_backend.reset(nwalkers=nwalkers, ndim=ndim)
    with Pool(n_processes) as pool:
        sampler = emcee.EnsembleSampler(
            nwalkers, ndim, log_likelihood, pool=pool, backend=burnin_backend)
        state = sampler.run_mcmc(initial_values_burn, n_burn, progress=True)

    # Start the main run from the burn-in's final positions, ordered from
    # highest to lowest log-likelihood.
    df_x = pd.DataFrame(state.coords, columns=likelihood.PARAM_NAMES)
    df_x['log_likelihood'] = state.log_prob
    df_x = df_x.sort_values('log_likelihood', ascending=False)
    initial_values = df_x.values[:, :-1]

    # Main run, starting from a fresh backend file.
    if os.path.exists(main_path):
        os.remove(main_path)
    new_backend = emcee.backends.HDFBackend(main_path)
    new_backend.reset(nwalkers=nwalkers, ndim=ndim)
    with Pool(n_processes) as pool:
        sampler = emcee.EnsembleSampler(
            nwalkers, ndim, log_likelihood, backend=new_backend, pool=pool)
        sampler.run_mcmc(initial_state=initial_values, nsteps=max_n, progress=True)

    # Only the main chain is kept.
    os.remove(burn_path)
